{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c17c940a-b7cf-4414-92bc-1dfa87e4b240",
   "metadata": {},
   "source": [
    "# Adaptive State Transition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d87d712d-283c-4683-b7aa-5195f851e811",
   "metadata": {},
   "source": [
    "We will use Burgers' equation to demonstrate the adaptive state transition for PDEs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25e33609-b37d-436d-b2b0-200577847ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import optax\n",
    "from flax import linen as nn\n",
    "\n",
    "import sys\n",
    "import os\n",
    "\n",
    "import time\n",
    "\n",
    "# Add /src to path\n",
    "path_to_src = os.path.abspath(os.path.join(os.getcwd(), '../../../src'))\n",
    "if path_to_src not in sys.path:\n",
    "    sys.path.append(path_to_src)\n",
    "\n",
    "from KAN import KAN\n",
    "from PIKAN import *\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dca1f133-8abf-44fe-9f98-2cf13efb29ed",
   "metadata": {},
   "source": [
    "### Collocation Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd634f87-0bea-47ea-834e-5b4b4afa40db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collocation points for the PDE residual (column 0 is t, column 1 is x)\n",
    "N = 2**12\n",
    "collocs = jnp.array(sobol_sample(np.array([0, -1]), np.array([1, 1]), N))  # (4096, 2)\n",
    "\n",
    "# Collocation points for the BCs\n",
    "N = 2**6\n",
    "\n",
    "# t = 0 with target -sin(pi * x)\n",
    "BC1_colloc = jnp.array(sobol_sample(np.array([0, -1]), np.array([0, 1]), N))  # (64, 2)\n",
    "BC1_data = -jnp.sin(np.pi * BC1_colloc[:, 1])[:, None]  # (64, 1)\n",
    "\n",
    "# x = -1 with target zero\n",
    "BC2_colloc = jnp.array(sobol_sample(np.array([0, -1]), np.array([1, -1]), N))  # (64, 2)\n",
    "BC2_data = jnp.zeros((BC2_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# x = 1 with target zero\n",
    "BC3_colloc = jnp.array(sobol_sample(np.array([0, 1]), np.array([1, 1]), N))  # (64, 2)\n",
    "BC3_data = jnp.zeros((BC3_colloc.shape[0], 1))  # (64, 1)\n",
    "\n",
    "# Condition points and their targets, in matching order\n",
    "bc_collocs = [BC1_colloc, BC2_colloc, BC3_colloc]\n",
    "bc_data = [BC1_data, BC2_data, BC3_data]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eefadd88-07ae-407c-88cb-d46dd35158f4",
   "metadata": {},
   "source": [
    "### Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7909b8a2-a88e-495c-9fc3-0be21e2a4c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pde_loss(params, collocs, state):\n",
    "    \"\"\"Residual of Burgers' equation, u_t + u*u_x - v*u_xx, at the collocation points.\n",
    "\n",
    "    Args:\n",
    "        params: passed to the model as its 'params' collection.\n",
    "        collocs: collocation points; column 0 is t and column 1 is x.\n",
    "        state: passed to the model as its 'state' collection.\n",
    "\n",
    "    Returns:\n",
    "        The PDE residual evaluated at `collocs`.\n",
    "\n",
    "    Uses the notebook-level `model` and the `gradf` helper brought in by the PIKAN import.\n",
    "    \"\"\"\n",
    "    # Eq. parameter\n",
    "    v = jnp.array(0.01/jnp.pi, dtype=float)\n",
    "\n",
    "    # Define the model function\n",
    "    variables = {'params' : params, 'state' : state}\n",
    "\n",
    "    def u(vec_x):\n",
    "        # The model returns a second output as well, which the residual does not need\n",
    "        y, _ = model.apply(variables, vec_x)\n",
    "        return y\n",
    "\n",
    "    # Physics Loss Terms\n",
    "    u_t = gradf(u, 0, 1)  # 1st order derivative of t\n",
    "    u_x = gradf(u, 1, 1)  # 1st order derivative of x\n",
    "    u_xx = gradf(u, 1, 2) # 2nd order derivative of x\n",
    "\n",
    "    # Residual\n",
    "    pde_res = u_t(collocs) + u(collocs)*u_x(collocs) - v*u_xx(collocs)\n",
    "\n",
    "    return pde_res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "645ec2fa-1e24-47f9-bb5c-0cb9d94a13b7",
   "metadata": {},
   "source": [
    "## Discontinuities"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a58c2b06-fd40-4f88-a98d-0fc1fd7eb34e",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "### Full Update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bf038b5-1d07-4dcf-bd0e-86d3e23d54ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "layer_dims = [2, 8, 8, 1]\n",
    "model = KAN(\n",
    "    layer_dims=layer_dims,\n",
    "    k=3,\n",
    "    const_spl=False,\n",
    "    const_res=False,\n",
    "    add_bias=True,\n",
    "    grid_e=0.05,\n",
    ")\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))\n",
    "\n",
    "# Learning rate settings for the scheduler\n",
    "lr_vals = {'init_lr': 0.003, 'scales': {0: 1.0}}\n",
    "\n",
    "# Epochs for grid extension, mapped to the grid sizes\n",
    "grid_extend = {0: 4, 750: 15}\n",
    "\n",
    "# Global loss weights (four entries, all equal to one)\n",
    "glob_w = [jnp.array(1.0, dtype=float) for _ in range(4)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be30eb3c-9594-4851-bfad-45f35954b791",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 1500\n",
    "\n",
    "# Train with the settings defined above and keep the loss history of this run\n",
    "model, variables, losses_full_upd = train_PIKAN(\n",
    "    model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "    glob_w=glob_w,\n",
    "    lr_vals=lr_vals,\n",
    "    adapt_state=False,\n",
    "    loc_w=None,\n",
    "    nesterov=False,\n",
    "    num_epochs=num_epochs,\n",
    "    grid_extend=grid_extend,\n",
    "    grid_adapt=[],\n",
    "    colloc_adapt={'epochs' : []},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "137cdacd-b0ef-4096-b629-60f03d3d723a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss per epoch on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "ax.plot(np.array(losses_full_upd), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90358ea4-ed07-4242-9e84-d91f7a56de9d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "### Adaptation only, no extension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "922ec2ab-40cf-48d7-856c-61f7e8edc5c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "layer_dims = [2, 8, 8, 1]\n",
    "model = KAN(\n",
    "    layer_dims=layer_dims,\n",
    "    k=3,\n",
    "    const_spl=False,\n",
    "    const_res=False,\n",
    "    add_bias=True,\n",
    "    grid_e=0.05,\n",
    ")\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))\n",
    "\n",
    "# Learning rate settings for the scheduler\n",
    "lr_vals = {'init_lr': 0.003, 'scales': {0: 1.0}}\n",
    "\n",
    "# Epochs for grid extension, mapped to the grid sizes (initial grid only)\n",
    "grid_extend = {0: 4}\n",
    "\n",
    "# Epochs for grid adaptation\n",
    "grid_adapt = [750]\n",
    "\n",
    "# Global loss weights (four entries, all equal to one)\n",
    "glob_w = [jnp.array(1.0, dtype=float) for _ in range(4)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d9d735-eff0-413d-b9a3-2cc124d8ab84",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 1500\n",
    "\n",
    "# Train with the settings defined above and keep the loss history of this run\n",
    "model, variables, losses_adapt_upd = train_PIKAN(\n",
    "    model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "    glob_w=glob_w,\n",
    "    lr_vals=lr_vals,\n",
    "    adapt_state=False,\n",
    "    loc_w=None,\n",
    "    nesterov=False,\n",
    "    num_epochs=num_epochs,\n",
    "    grid_extend=grid_extend,\n",
    "    grid_adapt=grid_adapt,\n",
    "    colloc_adapt={'epochs' : []},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2456b35b-edf0-42b3-b98a-d1a507c25b60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss per epoch on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "ax.plot(np.array(losses_adapt_upd), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d51966d3-436c-49e9-b332-d136aab3374b",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "### Optimizer Restart only"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d5542ee-47a3-410a-a84a-011c4e699d2f",
   "metadata": {},
   "source": [
    "For the following three cells to work, the `train_PIKAN` function must first be temporarily modified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5851f6a-8bb7-4f28-87ce-67b01c8077cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "layer_dims = [2, 8, 8, 1]\n",
    "model = KAN(\n",
    "    layer_dims=layer_dims,\n",
    "    k=3,\n",
    "    const_spl=False,\n",
    "    const_res=False,\n",
    "    add_bias=True,\n",
    "    grid_e=0.05,\n",
    ")\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))\n",
    "\n",
    "# Learning rate settings for the scheduler\n",
    "lr_vals = {'init_lr': 0.003, 'scales': {0: 1.0}}\n",
    "\n",
    "# Epochs for grid extension, mapped to the grid sizes (initial grid only)\n",
    "grid_extend = {0: 4}\n",
    "\n",
    "# Global loss weights (four entries, all equal to one)\n",
    "glob_w = [jnp.array(1.0, dtype=float) for _ in range(4)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c8e69d4-6b84-4e26-877c-c5fbe01548c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 1500\n",
    "\n",
    "# Train with the settings defined above and keep the loss history of this run\n",
    "model, variables, losses_opt_upd = train_PIKAN(\n",
    "    model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "    glob_w=glob_w,\n",
    "    lr_vals=lr_vals,\n",
    "    adapt_state=False,\n",
    "    loc_w=None,\n",
    "    nesterov=False,\n",
    "    num_epochs=num_epochs,\n",
    "    grid_extend=grid_extend,\n",
    "    grid_adapt=[],\n",
    "    colloc_adapt={'epochs' : []},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a496edd-7a35-40c9-a9b7-79d06bcc7878",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss per epoch on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "ax.plot(np.array(losses_opt_upd), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "552cf784-2c6a-484e-bc21-5735a3efc37c",
   "metadata": {},
   "source": [
    "## Resolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7832afd8-8bdb-4ab5-9571-f4577e0f0e2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "layer_dims = [2, 8, 8, 1]\n",
    "model = KAN(\n",
    "    layer_dims=layer_dims,\n",
    "    k=3,\n",
    "    const_spl=False,\n",
    "    const_res=False,\n",
    "    add_bias=True,\n",
    "    grid_e=0.05,\n",
    ")\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))\n",
    "\n",
    "# Learning rate settings for the scheduler\n",
    "lr_vals = {'init_lr': 0.003, 'scales': {0: 1.0, 749: 0.3}}\n",
    "\n",
    "# Epochs for grid adaptation: every multiple of `adapt_every` that does not exceed `adapt_stop`\n",
    "adapt_every = 101\n",
    "adapt_stop = 750\n",
    "grid_adapt = list(range(adapt_every, adapt_stop + 1, adapt_every))\n",
    "\n",
    "# Epochs for grid extension, mapped to the grid sizes\n",
    "grid_extend = {0: 4, 750: 15}\n",
    "\n",
    "# Global loss weights (four entries, all equal to one)\n",
    "glob_w = [jnp.array(1.0, dtype=float) for _ in range(4)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c1cecc2-d43d-4383-8814-4ff603f82941",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 1500\n",
    "\n",
    "# Train with the settings defined above and keep the loss history of this run\n",
    "model, variables, losses_full_res = train_PIKAN(\n",
    "    model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "    glob_w=glob_w,\n",
    "    lr_vals=lr_vals,\n",
    "    adapt_state=True,\n",
    "    loc_w=None,\n",
    "    nesterov=True,\n",
    "    num_epochs=num_epochs,\n",
    "    grid_extend=grid_extend,\n",
    "    grid_adapt=grid_adapt,\n",
    "    colloc_adapt={'epochs' : []},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7adc28b-2c5d-484d-b3f0-3fae147fbb4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss per epoch on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "ax.plot(np.array(losses_full_res), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba69df20-4a80-4c3f-93ae-905691d036c6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b56790ab-26eb-48ef-a3f5-120cd9fe4c84",
   "metadata": {},
   "source": [
    "### Save Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d17943d4-95ff-4942-a7e5-a86a8ff23d83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch indices and the grid-extension epochs, stored alongside the losses\n",
    "epochs = np.arange(num_epochs)\n",
    "grids = np.array([*grid_extend])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3a175c3-4fa1-4c51-b815-ea07c3e58ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run only when running the optimizer experiment:\n",
    "# writes its loss history to a placeholder file\n",
    "np.savez(\n",
    "    'plch.npz',\n",
    "    losses_opt_upd=losses_opt_upd,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edc5f889-5139-43f5-ab61-3f27e388d10b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the optimizer-experiment losses back from the placeholder file.\n",
    "# The context manager closes the .npz file once the array has been read.\n",
    "with np.load('plch.npz') as plch:\n",
    "    losses_opt_upd = plch['losses_opt_upd']\n",
    "\n",
    "# Store the loss histories of all four PDE experiments in a single file\n",
    "np.savez('../Plots/data/state.npz', epochs=epochs, grids=grids, losses_full_upd=losses_full_upd,\n",
    "         losses_adapt_upd=losses_adapt_upd, losses_opt_upd=losses_opt_upd, losses_full_res=losses_full_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f9285e1-7db1-4db0-841e-ebcc52f9d2c4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ff3a5356-8cf4-4c41-8569-b10bae1e1f34",
   "metadata": {},
   "source": [
    "## Curve Fitting\n",
    "\n",
    "The effect of the smooth state transition, even without adaptation, is very pronounced for the problem of curve fitting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b37e6e8-49b1-4007-9f20-225a034a7897",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use pykan's methods for direct reference\n",
    "import torch\n",
    "\n",
    "from kan.utils import create_dataset\n",
    "\n",
    "\n",
    "def f(x):\n",
    "    # Target function: exp of the average of two sine terms, one per pair of input columns\n",
    "    first_pair = torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))\n",
    "    second_pair = torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2))\n",
    "    return torch.exp((first_pair+second_pair)/2)\n",
    "\n",
    "\n",
    "dataset = create_dataset(f, n_var=4, train_num=3000)\n",
    "\n",
    "# Convert the torch tensors of the dataset to JAX arrays\n",
    "X_train, X_test, y_train, y_test = (\n",
    "    jnp.array(dataset[split].numpy())\n",
    "    for split in ('train_input', 'test_input', 'train_label', 'test_label')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35302e1a-dbde-4d34-aef0-28d8424b32a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "layer_dims = [4, 5, 2, 1]\n",
    "\n",
    "model = KAN(\n",
    "    layer_dims=layer_dims,\n",
    "    k=3,\n",
    "    const_spl=False,\n",
    "    const_res=False,\n",
    "    add_bias=True,\n",
    "    grid_e=0.02,\n",
    ")\n",
    "variables = model.init(key, jnp.ones([10, 4]))\n",
    "\n",
    "# Root-mean-square error of the model predictions\n",
    "def loss_fn(params, x, y, state):\n",
    "    \"\"\"Return the RMSE between the model's predictions for `x` and the targets `y`.\n",
    "\n",
    "    `params` and `state` are passed to the model as its 'params' and 'state' collections.\n",
    "    \"\"\"\n",
    "    # The second output of the forward pass is not used by this loss\n",
    "    preds, _ = model.apply({'params' : params, 'state' : state}, x)\n",
    "\n",
    "    squared_err = (preds-y)**2\n",
    "    return jnp.sqrt(jnp.mean(squared_err))\n",
    "\n",
    "@jax.jit\n",
    "def train_step(params, opt_state, x, y, state):\n",
    "    \"\"\"Run one optimization step and return (variables, opt_state, loss).\n",
    "\n",
    "    The loss is differentiated with respect to `params` only; `state` is placed unchanged\n",
    "    into the returned variables dict.\n",
    "\n",
    "    NOTE(review): `optimizer` is a notebook-level global rather than an argument of this\n",
    "    jitted function - confirm that redefining it in a later cell is picked up as intended.\n",
    "    \"\"\"\n",
    "    loss, grads = jax.value_and_grad(loss_fn)(params, x, y, state)\n",
    "\n",
    "    updates, new_opt_state = optimizer.update(grads, opt_state, params)\n",
    "    new_params = optax.apply_updates(params, updates)\n",
    "\n",
    "    return {'params': new_params, 'state': state}, new_opt_state, loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4f3c345-8626-4c60-ab5b-17fd9fa36ec5",
   "metadata": {},
   "source": [
    "## Non-smooth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b338297f-04f0-4166-8e31-24b303663ef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a fresh initialization so that re-running this cell does not\n",
    "# continue training from the variables left behind by a previous run\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([10, 4]))\n",
    "\n",
    "# Initialize optimizer\n",
    "boundaries = [0, 200, 400, 600] # Epochs at which to change the grid\n",
    "grid_vals = [3, 6, 10, 24] # Grid sizes to use\n",
    "# Corresponding dict\n",
    "grid_upds = dict(zip(boundaries, grid_vals))\n",
    "\n",
    "init_lr = 0.02\n",
    "optimizer = optax.adam(learning_rate=init_lr, nesterov=False)\n",
    "opt_state = optimizer.init(variables['params'])\n",
    "\n",
    "# Training epochs\n",
    "num_epochs = 800\n",
    "\n",
    "loss_non_smooth = []\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Check if we're in an update epoch\n",
    "    if epoch in grid_upds:\n",
    "        print(f\"Epoch {epoch+1}: Performing grid update\")\n",
    "        # Get grid size\n",
    "        G_new = grid_upds[epoch]\n",
    "        # Perform the update\n",
    "        updated_variables = model.apply(variables, X_train, G_new, method=model.update_grids)\n",
    "        variables = updated_variables.copy()\n",
    "        # Re-initialize the optimizer from scratch (abrupt, non-smooth transition)\n",
    "        opt_state = optimizer.init(variables['params'])\n",
    "\n",
    "    # Perform one training step and record its loss\n",
    "    params, state = variables['params'], variables['state']\n",
    "    variables, opt_state, loss = train_step(params, opt_state, X_train, y_train, state)\n",
    "\n",
    "    loss_non_smooth.append(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3cca0ea-4cc2-48a2-ad39-3bb34cccf6a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss per epoch on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "ax.plot(np.array(loss_non_smooth), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d703e44-3747-4fff-b88c-3a6f2e9a555f",
   "metadata": {},
   "source": [
    "## Smooth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f34dd97-0486-4db5-92b5-99ed8b99f27a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a fresh initialization: without this, training would continue from the\n",
    "# variables left behind by the previously run experiment cell, so the two runs\n",
    "# would not be compared from the same starting point\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([10, 4]))\n",
    "\n",
    "# Initialize optimizer\n",
    "boundaries = [0, 200, 400, 600] # Epochs at which to change the grid & learning rate\n",
    "scales = [1.0, 0.5, 0.2, 0.5] # Learning rate scales\n",
    "grid_vals = [3, 6, 10, 24] # Grid sizes to use\n",
    "\n",
    "init_lr = 0.02\n",
    "\n",
    "# Corresponding dicts\n",
    "lr_scales = dict(zip(boundaries, scales))\n",
    "grid_upds = dict(zip(boundaries, grid_vals))\n",
    "\n",
    "# Create a piecewise constant schedule\n",
    "schedule = optax.piecewise_constant_schedule(\n",
    "    init_value=init_lr,\n",
    "    boundaries_and_scales=lr_scales\n",
    ")\n",
    "\n",
    "optimizer = optax.adam(learning_rate=schedule, nesterov=True)\n",
    "\n",
    "opt_state = optimizer.init(variables['params'])\n",
    "\n",
    "# Training epochs\n",
    "num_epochs = 800\n",
    "\n",
    "loss_smooth = []\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Check if we're in an update epoch\n",
    "    if epoch in grid_upds:\n",
    "        print(f\"Epoch {epoch+1}: Performing grid update\")\n",
    "        # Get grid size\n",
    "        G_new = grid_upds[epoch]\n",
    "        # Perform the update\n",
    "        updated_variables = model.apply(variables, X_train, G_new, method=model.update_grids)\n",
    "        variables = updated_variables.copy()\n",
    "        # Re-initialize optimizer smoothly\n",
    "        opt_state = state_transition(opt_state, variables)\n",
    "\n",
    "    # Perform one training step and record its loss\n",
    "    params, state = variables['params'], variables['state']\n",
    "    variables, opt_state, loss = train_step(params, opt_state, X_train, y_train, state)\n",
    "\n",
    "    loss_smooth.append(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17917df6-674f-4897-9c28-6167aac5e3ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss per epoch on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "ax.plot(np.array(loss_smooth), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "481da5a3-f482-4603-98b9-13bbe621dd87",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9709c2f4-b385-4c42-a403-4fe9558ae79d",
   "metadata": {},
   "source": [
    "### Save Results for Appendix Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e79fdcc-0334-489d-8139-90b8592d4559",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch indices and the grid-update epochs, stored alongside the losses\n",
    "epochs = np.arange(num_epochs)\n",
    "grids = np.array([*grid_upds])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33bb421c-07cf-487d-98fa-10b1bb47eb8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store both curve-fitting loss histories together with the epoch and grid-update indices\n",
    "np.savez(\n",
    "    '../Plots/data/state_curve.npz',\n",
    "    epochs=epochs,\n",
    "    grids=grids,\n",
    "    with_a=loss_smooth,\n",
    "    without_a=loss_non_smooth,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ae82aab-fc76-43e5-b5c8-7a4f29f51b18",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
